{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "# Multi-worker training with Keras\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/distribute/multi_worker_with_keras.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/distribute/multi_worker_with_keras.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/distribute/multi_worker_with_keras.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xHxb-dlhMIzW"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This tutorial demonstrates multi-worker distributed training with Keras model using `tf.distribute.Strategy` API, specifically `tf.distribute.MultiWorkerMirroredStrategy`. With the help of this strategy, a Keras model that was designed to run on single-worker can seamlessly work on multiple workers with minimal code change.\n",
        "\n",
        "[Distributed Training in TensorFlow](../../guide/distributed_training.ipynb) guide is available for an overview of the distribution strategies TensorFlow supports for those interested in a deeper understanding of `tf.distribute.Strategy` APIs.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MUXex9ctTuDB"
      },
      "source": [
        "## Setup\n",
        "\n",
        "First, some necessary imports."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bnYxvfLD-LW-"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "import os\n",
        "import sys"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Zz0EY91y3mxy"
      },
      "source": [
        "Before importing TensorFlow, make a few changes to the environment.\n",
        "\n",
        "Disable all GPUs. This prevents errors caused by the workers all trying to use the same GPU. For a real application each worker would be on a different machine."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "685pbYEY3jGC"
      },
      "outputs": [],
      "source": [
        "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7X1MS6385BWi"
      },
      "source": [
        "Reset the `TF_CONFIG` environment variable, you'll see more about this later."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WEJLYa2_7OZF"
      },
      "outputs": [],
      "source": [
        "os.environ.pop('TF_CONFIG', None)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rd4L9Ii77SS8"
      },
      "source": [
        "Be sure that the current directory is on python's path. This allows the notebook to import the files written by `%%writefile` later.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hPBuZUNSZmrQ"
      },
      "outputs": [],
      "source": [
        "if '.' not in sys.path:\n",
        "  sys.path.insert(0, '.')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pDhHuMjb7bfU"
      },
      "source": [
        "Now import TensorFlow."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vHNvttzV43sA"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0S2jpf6Sx50i"
      },
      "source": [
        "### Dataset and model definition"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fLW6D2TzvC-4"
      },
      "source": [
        "Next create an `mnist.py` file with a simple model and dataset setup. This python file will be used by the worker-processes in this tutorial:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dma_wUAxZqo2"
      },
      "outputs": [],
      "source": [
        "%%writefile mnist.py\n",
        "\n",
        "import os\n",
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "\n",
        "def mnist_dataset(batch_size):\n",
        "  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()\n",
        "  # The `x` arrays are in uint8 and have values in the range [0, 255].\n",
        "  # You need to convert them to float32 with values in the range [0, 1]\n",
        "  x_train = x_train / np.float32(255)\n",
        "  y_train = y_train.astype(np.int64)\n",
        "  train_dataset = tf.data.Dataset.from_tensor_slices(\n",
        "      (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)\n",
        "  return train_dataset\n",
        "\n",
        "def build_and_compile_cnn_model():\n",
        "  model = tf.keras.Sequential([\n",
        "      tf.keras.Input(shape=(28, 28)),\n",
        "      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
        "      tf.keras.layers.Conv2D(32, 3, activation='relu'),\n",
        "      tf.keras.layers.Flatten(),\n",
        "      tf.keras.layers.Dense(128, activation='relu'),\n",
        "      tf.keras.layers.Dense(10)\n",
        "  ])\n",
        "  model.compile(\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
        "      metrics=['accuracy'])\n",
        "  return model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2UL3kisMO90X"
      },
      "source": [
        "Try training the model for a small number of epochs and observe the results of a single worker to make sure everything works correctly. As training progresses, the loss should drop and the accuracy should increase."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6Qe6iAf5O8iJ"
      },
      "outputs": [],
      "source": [
        "import mnist\n",
        "\n",
        "batch_size = 64\n",
        "single_worker_dataset = mnist.mnist_dataset(batch_size)\n",
        "single_worker_model = mnist.build_and_compile_cnn_model()\n",
        "single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JmgZwwymxqt5"
      },
      "source": [
        "## Multi-worker Configuration\n",
        "\n",
        "Now let's enter the world of multi-worker training. In TensorFlow, the `TF_CONFIG` environment variable is required for training on multiple machines, each of which possibly has a different role. `TF_CONFIG` is a JSON string used to specify the cluster configuration on each worker that is part of the cluster.\n",
        "\n",
        "Here is an example configuration:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XK1eTYvSZiX7"
      },
      "outputs": [],
      "source": [
        "tf_config = {\n",
        "    'cluster': {\n",
        "        'worker': ['localhost:12345', 'localhost:23456']\n",
        "    },\n",
        "    'task': {'type': 'worker', 'index': 0}\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JjgwJbPKZkJL"
      },
      "source": [
        "Here is the same `TF_CONFIG` serialized as a JSON string:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yY-T0YDQZjbu"
      },
      "outputs": [],
      "source": [
        "json.dumps(tf_config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AUBmYRZqxthH"
      },
      "source": [
        "There are two components of `TF_CONFIG`: `cluster` and `task`.\n",
        "\n",
        "* `cluster` is the same for all workers and provides information about the training cluster, which is a dict consisting of different types of jobs such as `worker`. In multi-worker training with `MultiWorkerMirroredStrategy`, there is usually one `worker` that takes on a little more responsibility like saving checkpoint and writing summary file for TensorBoard in addition to what a regular `worker` does. Such a worker is referred to as the `chief` worker, and it is customary that the `worker` with `index` 0 is appointed as the chief `worker` (in fact this is how `tf.distribute.Strategy` is implemented).\n",
        "\n",
        "* `task` provides information of the current task and is different on each worker. It specifies the `type` and `index` of that worker. "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8YFpxrcsZ2xG"
      },
      "source": [
        "In this example, you set the task `type` to `\"worker\"` and the task `index` to `0`. This machine is the first worker and will be appointed as the chief worker and do more work than the others. Note that other machines will need to have the `TF_CONFIG` environment variable set as well, and it should have the same `cluster` dict, but different task `type` or task `index` depending on what the roles of those machines are.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aogb74kHxynz"
      },
      "source": [
        "For illustration purposes, this tutorial shows how one may set a `TF_CONFIG` with 2 workers on `localhost`.  In practice, users would create multiple workers on external IP addresses/ports, and set `TF_CONFIG` on each worker appropriately.\n",
        "\n",
        "In this example you will use 2 workers, the first worker's `TF_CONFIG` is shown above. For the second worker you would set `tf_config['task']['index']=1`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f83FVYqDX3aX"
      },
      "source": [
        "Above, `tf_config` is just a local variable in python. To actually use it to configure training, this dictionary needs to be serialized as JSON, and placed in the `TF_CONFIG` environment variable."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cIlkfWmjz1PG"
      },
      "source": [
        "### Environment variables and subprocesses in notebooks"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FcjAbuGY1ACJ"
      },
      "source": [
        "Subprocesses inherit environment variables from their parent. So if you set an environment variable in this `jupyter notebook` process:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PH2gHn2_0_U8"
      },
      "outputs": [],
      "source": [
        "os.environ['GREETINGS'] = 'Hello TensorFlow!'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gQkIX-cg18md"
      },
      "source": [
        "You can access the environment variable from a subprocesses:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pquKO6IA18G5"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "echo ${GREETINGS}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "af6BCA-Y2fpz"
      },
      "source": [
        "In the next section, you'll use this to pass the `TF_CONFIG` to the worker subprocesses. You would never really launch your jobs this way, but it's sufficient for the purposes of this tutorial: To demonstrate a minimal multi-worker example."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UhNtHfuxCGVy"
      },
      "source": [
        "## Choose the right strategy\n",
        "\n",
        "In TensorFlow there are two main forms of distributed training:\n",
        "\n",
        "* Synchronous training, where the steps of training are synced across the workers and replicas, and\n",
        "* Asynchronous training, where the training steps are not strictly synced.\n",
        "\n",
        "`MultiWorkerMirroredStrategy`, which is the recommended strategy for synchronous multi-worker training, will be demonstrated in this guide.\n",
        "To train the model, use an instance of `tf.distribute.MultiWorkerMirroredStrategy`.\n",
        "\n",
        "`MultiWorkerMirroredStrategy` creates copies of all variables in the model's layers on each device across all workers.  It uses `CollectiveOps`, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync.  The [`tf.distribute.Strategy` guide](../../guide/distributed_training.ipynb) has more details about this strategy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1uFSHCJXMrQ-"
      },
      "outputs": [],
      "source": [
        "strategy = tf.distribute.MultiWorkerMirroredStrategy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N0iv7SyyAohc"
      },
      "source": [
        "Note: `TF_CONFIG` is parsed and TensorFlow's GRPC servers are started at the time `MultiWorkerMirroredStrategy()` is called, so the `TF_CONFIG` environment variable must be set before a `tf.distribute.Strategy` instance is created. Since `TF_CONFIG` is not set yet the above strategy is effectively single-worker training."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FMy2VM4Akzpr"
      },
      "source": [
        "`MultiWorkerMirroredStrategy` provides multiple implementations via the [`CommunicationOptions`](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/CommunicationOptions) parameter.  `RING` implements ring-based collectives using gRPC as the cross-host communication layer.  `NCCL` uses [Nvidia's NCCL](https://developer.nvidia.com/nccl) to implement collectives.  `AUTO` defers the choice to the runtime.  The best choice of collective implementation depends upon the number and kind of GPUs, and the network interconnect in the cluster."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H47DDcOgfzm7"
      },
      "source": [
        "## Train the model\n",
        "\n",
        "With the integration of `tf.distribute.Strategy` API into `tf.keras`, the only change you will make to distribute the training to multiple-workers is enclosing the model building and `model.compile()` call inside `strategy.scope()`. The distribution strategy's scope dictates how and where the variables are created, and in the case of `MultiWorkerMirroredStrategy`, the variables created are `MirroredVariable`s, and they are replicated on each of the workers.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wo6b9wX65glL"
      },
      "outputs": [],
      "source": [
        "with strategy.scope():\n",
        "  # Model building/compiling need to be within `strategy.scope()`.\n",
        "  multi_worker_model = mnist.build_and_compile_cnn_model()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mhq3fzyR5hTw"
      },
      "source": [
        "Note: Currently there is a limitation in `MultiWorkerMirroredStrategy` where TensorFlow ops need to be created after the instance of strategy is created. If you see `RuntimeError: Collective ops must be configured at program startup`, try creating the instance of `MultiWorkerMirroredStrategy` at the beginning of the program and put the code that may create ops after the strategy is instantiated."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jfYpmIxO6Jck"
      },
      "source": [
        "To actually run with `MultiWorkerMirroredStrategy` you'll need to run worker processes and pass a `TF_CONFIG` to them.\n",
        "\n",
        "Like the `mnist.py` file written earlier, here is the `main.py` that each of the workers will run:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BcsuBYrpgnlS"
      },
      "outputs": [],
      "source": [
        "%%writefile main.py\n",
        "\n",
        "import os\n",
        "import json\n",
        "\n",
        "import tensorflow as tf\n",
        "import mnist\n",
        "\n",
        "per_worker_batch_size = 64\n",
        "tf_config = json.loads(os.environ['TF_CONFIG'])\n",
        "num_workers = len(tf_config['cluster']['worker'])\n",
        "\n",
        "strategy = tf.distribute.MultiWorkerMirroredStrategy()\n",
        "\n",
        "global_batch_size = per_worker_batch_size * num_workers\n",
        "multi_worker_dataset = mnist.mnist_dataset(global_batch_size)\n",
        "\n",
        "with strategy.scope():\n",
        "  # Model building/compiling need to be within `strategy.scope()`.\n",
        "  multi_worker_model = mnist.build_and_compile_cnn_model()\n",
        "\n",
        "\n",
        "multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Aom9xelvJQ_6"
      },
      "source": [
        "In the code snippet above note that the `global_batch_size`, which gets passed to `Dataset.batch`, is set to `per_worker_batch_size * num_workers`. This ensures that each worker processes batches of `per_worker_batch_size` examples regardless of the number of workers."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lHLhOii67Saa"
      },
      "source": [
        "The current directory now contains both Python files:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bi6x05Sr60O9"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "ls *.py"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qmEEStPS6vR_"
      },
      "source": [
        "So json-serialize the `TF_CONFIG` and add it to the environment variables:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9uu3g7vV7Bbt"
      },
      "outputs": [],
      "source": [
        "os.environ['TF_CONFIG'] = json.dumps(tf_config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MsY3dQLK7jdf"
      },
      "source": [
        "Now, you can launch a worker process that will run the `main.py` and use the `TF_CONFIG`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "txMXaq8d8N_S"
      },
      "outputs": [],
      "source": [
        "# first kill any previous runs\n",
        "%killbgscripts"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qnSma_Ck7r-r"
      },
      "outputs": [],
      "source": [
        "%%bash --bg\n",
        "python main.py &> job_0.log"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZChyazqS7v0P"
      },
      "source": [
        "There are a few things to note about the above command:\n",
        "\n",
        "1. It uses the `%%bash` which is a [notebook \"magic\"](https://ipython.readthedocs.io/en/stable/interactive/magics.html) to run some bash commands.\n",
        "2. It uses the `--bg` flag to run the `bash` process in the background, because this worker will not terminate. It waits for all the workers before it starts.\n",
        "\n",
        "The backgrounded worker process won't print output to this notebook, so the `&>` redirects its output to a file, so you can see what happened.\n",
        "\n",
        "So, wait a few seconds for the process to start up:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Hm2yrULE9281"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "time.sleep(10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZFPoNxg_9_Mx"
      },
      "source": [
        "Now look what's been output to the worker's logfile so far:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vZEOuVgQ9-hn"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "cat job_0.log"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RqZhVF7L_KOy"
      },
      "source": [
        "The last line of the log file should say: `Started server with target: grpc://localhost:12345`. The first worker is now ready, and is waiting for all the other worker(s) to be ready to proceed."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Pi8vPNNA_l4a"
      },
      "source": [
        "So update the `tf_config` for the second worker's process to pick up:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lAiYkkPu_Jqd"
      },
      "outputs": [],
      "source": [
        "tf_config['task']['index'] = 1\n",
        "os.environ['TF_CONFIG'] = json.dumps(tf_config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0AshGVO0_x0w"
      },
      "source": [
        "Now launch the second worker. This will start the training since all the workers are active (so there's no need to background this process):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_ESVtyQ9_xjx"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "python main.py"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hX4FA2O2AuAn"
      },
      "source": [
        "Now if you recheck the logs written by the first worker you'll see that it participated in training that model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rc6hw3yTBKXX"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "cat job_0.log"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zL79ak5PMzEg"
      },
      "source": [
        "Unsurprisingly this ran _slower_ than the the test run at the beginning of this tutorial. Running multiple workers on a single machine only adds overhead. The goal here was not to improve the training time, but only to give an example of multi-worker training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sG5_1UgrgniF"
      },
      "outputs": [],
      "source": [
        "# Delete the `TF_CONFIG`, and kill any background tasks so they don't affect the next section.\n",
        "os.environ.pop('TF_CONFIG', None)\n",
        "%killbgscripts"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9j2FJVHoUIrE"
      },
      "source": [
        "## Multi worker training in depth\n",
        "\n",
        "So far this tutorial has demonstrated a basic multi-worker setup. The rest of this document looks in detail other factors which may be useful or important for real use cases."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rr14Vl9GR4zq"
      },
      "source": [
        "### Dataset sharding\n",
        "\n",
        "In multi-worker training, dataset sharding is needed to ensure convergence and performance.\n",
        "\n",
        "The example in the previous section relies on the default autosharding provided by the `tf.distribute.Strategy` API. You can control the sharding by setting the `tf.data.experimental.AutoShardPolicy` of the `tf.data.experimental.DistributeOptions`. To learn more about auto-sharding see the [Distributed input guide](https://www.tensorflow.org/tutorials/distribute/input#sharding).\n",
        "\n",
        "Here is a quick example of how to turn OFF the auto sharding, so each replica processes every example (not recommended):\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JxEtdh1vH-TF"
      },
      "outputs": [],
      "source": [
        "options = tf.data.Options()\n",
        "options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF\n",
        "\n",
        "global_batch_size = 64\n",
        "multi_worker_dataset = mnist.mnist_dataset(batch_size=64)\n",
        "dataset_no_auto_shard = multi_worker_dataset.with_options(options)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gmqvlh5LhAoU"
      },
      "source": [
        "### Evaluation\n",
        "\n",
        "If you pass `validation_data` into `model.fit`, it will alternate between training and evaluation for each epoch. The evaluation taking `validation_data` is distributed across the same set of workers and the evaluation results are aggregated and available for all workers. Similar to training, the validation dataset is automatically sharded at the file level. You need to set a global batch size in the validation dataset and set `validation_steps`. A repeated dataset is also recommended for evaluation.\n",
        "\n",
        "Alternatively, you can also create another task that periodically reads checkpoints and runs the evaluation. This is what Estimator does. But this is not a recommended way to perform evaluation and thus its details are omitted."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XVk4ftYx6JAO"
      },
      "source": [
        "### Performance\n",
        "\n",
        "You now have a Keras model that is all set up to run in multiple workers with `MultiWorkerMirroredStrategy`.  You can try the following techniques to tweak performance of multi-worker training with `MultiWorkerMirroredStrategy`.\n",
        "\n",
        "*   `MultiWorkerMirroredStrategy` provides multiple [collective communication implementations](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/CommunicationImplementation).  `RING` implements ring-based collectives using gRPC as the cross-host communication layer.  `NCCL` uses [Nvidia's NCCL](https://developer.nvidia.com/nccl) to implement collectives.  `AUTO` defers the choice to the runtime.  The best choice of collective implementation depends upon the number and kind of GPUs, and the network interconnect in the cluster.  To override the automatic choice, specify `communication_options` parameter of `MultiWorkerMirroredStrategy`'s constructor, e.g. `communication_options=tf.distribute.experimental.CommunicationOptions(implementation=tf.distribute.experimental.CollectiveCommunication.NCCL)`.\n",
        "*    Cast the variables to `tf.float` if possible.  The official ResNet model includes [an example](https://github.com/tensorflow/models/blob/8367cf6dabe11adf7628541706b660821f397dce/official/resnet/resnet_model.py#L466) of how this can be done.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "97WhAu8uKw3j"
      },
      "source": [
        "### Fault tolerance\n",
        "\n",
        "In synchronous training, the cluster would fail if one of the workers fails and no failure-recovery mechanism exists. Using Keras with `tf.distribute.Strategy` comes with the advantage of fault tolerance in cases where workers die or are otherwise unstable. You do this by preserving training state in the distributed file system of your choice, such that upon restart of the instance that previously failed or preempted, the training state is recovered.\n",
        "\n",
        "When a worker becomes unavailable, other workers will fail (possibly after a timeout). In such cases, the unavailable worker needs to be restarted, as well as other workers that have failed.\n",
        "\n",
        "Note:\n",
        "Previously, the `ModelCheckpoint` callback provided a mechanism to restore training state upon restart from job failure for multi-worker training. The TensorFlow team are introducing a new [`BackupAndRestore`](#scrollTo=kmH8uCUhfn4w) callback, to also add the support to single worker training for a consistent experience, and removed fault tolerance functionality from existing `ModelCheckpoint` callback. From now on, applications that rely on this behavior should migrate to the new callback."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KvHPjGlyyFt6"
      },
      "source": [
        "#### ModelCheckpoint callback\n",
        "\n",
        "`ModelCheckpoint` callback no longer provides fault tolerance functionality, please use [`BackupAndRestore`](#scrollTo=kmH8uCUhfn4w) callback instead.\n",
        "\n",
        "The `ModelCheckpoint` callback can still be used to save checkpoints. But with this, if training was interrupted or successfully finished, in order to continue training from the checkpoint, the user is responsible to load the model manually.\n",
        "\n",
        "Optionally the user can choose to save and restore model/weights outside `ModelCheckpoint` callback."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EUNV5Utc1d0s"
      },
      "source": [
        "### Model saving and loading\n",
        "\n",
        "To save your model using `model.save` or `tf.saved_model.save`, the destination for saving needs to be different for each worker. On the non-chief workers, you will need to save the model to a temporary directory, and on the chief, you will need to save to the provided model directory. The temporary directories on the worker need to be unique to prevent errors resulting from multiple workers trying to write to the same location. The model saved in all the directories are identical and typically only the model saved by the chief should be referenced for restoring or serving. You should have some cleanup logic that deletes the temporary directories created by the workers once your training has completed.\n",
        "\n",
        "You need to save on the chief and the workers at the same time because you might be aggregating variables during checkpointing, which requires both the chief and the workers to participate in the allreduce communication protocol. On the other hand, letting the chief and the workers save to the same model directory results in errors due to contention.\n",
        "\n",
        "With `MultiWorkerMirroredStrategy`, the program runs on every worker. To know whether the current worker is the chief, the program uses the cluster resolver object, which has the attributes `task_type` and `task_id`. `task_type` tells you what the current job is (e.g. 'worker'), and `task_id` tells you the identifier of the worker. The worker with id 0 is designated as the chief worker.\n",
        "\n",
        "In the code snippet below, `write_filepath` provides the file path to write, which depends on the worker id. For the chief (the worker with id 0), it returns the original file path; for the other workers, it creates a temporary directory (with the id in the directory path) to write in:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XQfGkmg-pfCY"
      },
      "outputs": [],
      "source": [
        "model_path = '/tmp/keras-model'\n",
        "\n",
        "def _is_chief(task_type, task_id):\n",
        "  # Note: there are two possible `TF_CONFIG` configurations.\n",
        "  #   1) In addition to `worker` tasks, a `chief` task type is used;\n",
        "  #      in this case, this function should be modified to\n",
        "  #      `return task_type == 'chief'`.\n",
        "  #   2) Only the `worker` task type is used; in this case, worker 0 is\n",
        "  #      regarded as the chief. The implementation demonstrated here\n",
        "  #      is for this case.\n",
        "  # For the purpose of this Colab section, the `task_type is None` case is\n",
        "  # also treated as chief, because it effectively runs with a single worker.\n",
        "  return (task_type == 'worker' and task_id == 0) or task_type is None\n",
        "\n",
        "def _get_temp_dir(dirpath, task_id):\n",
        "  base_dirpath = 'workertemp_' + str(task_id)\n",
        "  temp_dir = os.path.join(dirpath, base_dirpath)\n",
        "  tf.io.gfile.makedirs(temp_dir)\n",
        "  return temp_dir\n",
        "\n",
        "def write_filepath(filepath, task_type, task_id):\n",
        "  dirpath = os.path.dirname(filepath)\n",
        "  base = os.path.basename(filepath)\n",
        "  if not _is_chief(task_type, task_id):\n",
        "    dirpath = _get_temp_dir(dirpath, task_id)\n",
        "  return os.path.join(dirpath, base)\n",
        "\n",
        "task_type, task_id = (strategy.cluster_resolver.task_type,\n",
        "                      strategy.cluster_resolver.task_id)\n",
        "write_model_path = write_filepath(model_path, task_type, task_id)"
      ]
    },
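    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of the mapping from task to write location, the following sketch re-implements the same rule with the standard library only, so it can run without a cluster. The `demo_`-prefixed names are hypothetical stand-ins introduced for this illustration (they are not part of TensorFlow), and `os.makedirs` replaces `tf.io.gfile.makedirs` on the assumption that the path is on the local file system. The chief and the single-worker case should keep the original path, while the other workers get a `workertemp_` directory that includes their id:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "import tempfile\n",
        "\n",
        "def demo_is_chief(task_type, task_id):\n",
        "  # Same rule as `_is_chief`: worker 0, or no cluster information at all.\n",
        "  return (task_type == 'worker' and task_id == 0) or task_type is None\n",
        "\n",
        "def demo_write_filepath(filepath, task_type, task_id):\n",
        "  # Hypothetical stand-in for `write_filepath` that uses `os.makedirs`\n",
        "  # (local file system only) instead of `tf.io.gfile.makedirs`.\n",
        "  dirpath, base = os.path.split(filepath)\n",
        "  if not demo_is_chief(task_type, task_id):\n",
        "    dirpath = os.path.join(dirpath, 'workertemp_' + str(task_id))\n",
        "    os.makedirs(dirpath, exist_ok=True)\n",
        "  return os.path.join(dirpath, base)\n",
        "\n",
        "# Use a scratch directory so the demo does not touch `model_path`.\n",
        "demo_root = tempfile.mkdtemp()\n",
        "demo_model_path = os.path.join(demo_root, 'keras-model')\n",
        "\n",
        "# Print the write path, relative to the scratch directory, for a few tasks.\n",
        "for demo_task in [(None, None), ('worker', 0), ('worker', 1), ('worker', 2)]:\n",
        "  demo_path = demo_write_filepath(demo_model_path, *demo_task)\n",
        "  print(demo_task, '->', os.path.relpath(demo_path, demo_root))"
      ]
    },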
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hs0_agYR_qKm"
      },
      "source": [
        "With that, you're now ready to save:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J-yA3BYG_vTs"
      },
      "outputs": [],
      "source": [
        "multi_worker_model.save(write_model_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8LXUVVl9_v5x"
      },
      "source": [
        "As described above, the model should later be loaded only from the path the chief saved to, so remove the temporary copies that the non-chief workers saved:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aJTyu-97ABpY"
      },
      "outputs": [],
      "source": [
        "if not _is_chief(task_type, task_id):\n",
        "  tf.io.gfile.rmtree(os.path.dirname(write_model_path))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Nr-2PKlHAPBT"
      },
      "source": [
        "When it's time to load, use the convenient `tf.keras.models.load_model` API and continue with further work. This example assumes that a single worker loads the model and continues training, in which case you do not call `tf.keras.models.load_model` within another `strategy.scope()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iUZna-JKAOrX"
      },
      "outputs": [],
      "source": [
        "loaded_model = tf.keras.models.load_model(model_path)\n",
        "\n",
        "# Now that the model is restored, you can continue with the training.\n",
        "loaded_model.fit(single_worker_dataset, epochs=2, steps_per_epoch=20)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YJ1fmxmTpocS"
      },
      "source": [
        "### Checkpoint saving and restoring\n",
        "\n",
        "Checkpointing, on the other hand, allows you to save the model's weights and restore them without having to save the whole model. Here, you'll create one `tf.train.Checkpoint` that tracks the model, which is managed by a `tf.train.CheckpointManager` so that only the latest checkpoint is preserved."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_1-RYaB5xnNH"
      },
      "outputs": [],
      "source": [
        "checkpoint_dir = '/tmp/ckpt'\n",
        "\n",
        "checkpoint = tf.train.Checkpoint(model=multi_worker_model)\n",
        "write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)\n",
        "checkpoint_manager = tf.train.CheckpointManager(\n",
        "    checkpoint, directory=write_checkpoint_dir, max_to_keep=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7oBpPCRsW1MF"
      },
      "source": [
        "Once the `CheckpointManager` is set up, you're ready to save, and then to remove the checkpoints that the non-chief workers saved."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l1ZXG_GbWzLp"
      },
      "outputs": [],
      "source": [
        "checkpoint_manager.save()\n",
        "if not _is_chief(task_type, task_id):\n",
        "  tf.io.gfile.rmtree(write_checkpoint_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RO7cbN40XD5v"
      },
      "source": [
        "Now, when you need to restore, you can find the latest checkpoint saved using the convenient `tf.train.latest_checkpoint` function. After restoring the checkpoint, you can continue with training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NJW7vtknXFEH"
      },
      "outputs": [],
      "source": [
        "latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)\n",
        "checkpoint.restore(latest_checkpoint)\n",
        "multi_worker_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kmH8uCUhfn4w"
      },
      "source": [
        "#### BackupAndRestore callback\n",
        "\n",
        "The `BackupAndRestore` callback provides fault tolerance functionality by backing up the model and the current epoch number in a temporary checkpoint file under the directory given by the `backup_dir` argument to `BackupAndRestore`. This is done at the end of each epoch.\n",
        "\n",
        "When a job is interrupted and restarted, the callback restores the last checkpoint, and training continues from the beginning of the interrupted epoch. Any partial training already done in the unfinished epoch before the interruption is thrown away, so that it doesn't affect the final model state.\n",
        "\n",
        "To use it, provide an instance of `tf.keras.callbacks.experimental.BackupAndRestore` at the `tf.keras.Model.fit()` call.\n",
        "\n",
        "With `MultiWorkerMirroredStrategy`, if a worker gets interrupted, the whole cluster pauses until the interrupted worker is restarted. The other workers also restart, and the interrupted worker rejoins the cluster. Then, every worker reads the checkpoint file that was previously saved and picks up its former state, thereby allowing the cluster to get back in sync. Training then continues.\n",
        "\n",
        "The `BackupAndRestore` callback uses `CheckpointManager` to save and restore the training state, which generates a file called `checkpoint` that tracks the existing checkpoints together with the latest one. For this reason, `backup_dir` should not be reused to store other checkpoints, in order to avoid name collisions.\n",
        "\n",
        "Currently, the `BackupAndRestore` callback supports single-worker training with no strategy or with `MirroredStrategy`, and multi-worker training with `MultiWorkerMirroredStrategy`.\n",
        "Below is an example for multi-worker training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CYdzZi4Qs1jz"
      },
      "outputs": [],
      "source": [
        "# Multi-worker training with MultiWorkerMirroredStrategy.\n",
        "\n",
        "callbacks = [tf.keras.callbacks.experimental.BackupAndRestore(backup_dir='/tmp/backup')]\n",
        "with strategy.scope():\n",
        "  multi_worker_model = mnist.build_and_compile_cnn_model()\n",
        "multi_worker_model.fit(multi_worker_dataset,\n",
        "                       epochs=3,\n",
        "                       steps_per_epoch=70,\n",
        "                       callbacks=callbacks)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rIV5_3ebzXmB"
      },
      "source": [
        "If you inspect the `backup_dir` directory you specified in `BackupAndRestore`, you may notice some temporarily generated checkpoint files. Those files are needed for recovering the previously lost instances, and the library removes them at the end of `tf.keras.Model.fit()` once your training exits successfully.\n",
        "\n",
        "Note: Currently, `BackupAndRestore` only supports eager mode. In graph mode, consider using the [Save/Restore Model](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#model_saving_and_loading) approach described above and providing `initial_epoch` in `model.fit()`."
      ]
    },
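    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The restart behavior described in this section can be pictured with a toy loop. This is only a plain-Python sketch under simplifying assumptions, not how Keras implements the callback: `toy_fit`, `fail_at`, and the `toy_backup.json` file are hypothetical names, and the only state backed up is the next epoch number. It backs up at the end of each epoch, discards the partial work of an interrupted epoch, resumes from the start of that epoch, and removes the backup once training finishes:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import json\n",
        "import os\n",
        "import tempfile\n",
        "\n",
        "def toy_fit(backup_dir, epochs, steps_per_epoch, fail_at=None):\n",
        "  # Toy loop, not Keras: returns the (epoch, step) pairs that it ran.\n",
        "  backup_file = os.path.join(backup_dir, 'toy_backup.json')\n",
        "  start_epoch = 0\n",
        "  if os.path.exists(backup_file):\n",
        "    # A backup exists, so resume from the epoch recorded in it.\n",
        "    with open(backup_file) as f:\n",
        "      start_epoch = json.load(f)['next_epoch']\n",
        "  ran = []\n",
        "  for epoch in range(start_epoch, epochs):\n",
        "    for step in range(steps_per_epoch):\n",
        "      if fail_at == (epoch, step):\n",
        "        raise RuntimeError('simulated interruption')\n",
        "      ran.append((epoch, step))\n",
        "    # Back up only at the end of each completed epoch.\n",
        "    with open(backup_file, 'w') as f:\n",
        "      json.dump({'next_epoch': epoch + 1}, f)\n",
        "  # Training finished successfully, so remove the temporary backup.\n",
        "  if os.path.exists(backup_file):\n",
        "    os.remove(backup_file)\n",
        "  return ran\n",
        "\n",
        "toy_backup_dir = tempfile.mkdtemp()\n",
        "try:\n",
        "  # The first run is interrupted partway through the second epoch.\n",
        "  toy_fit(toy_backup_dir, epochs=3, steps_per_epoch=2, fail_at=(1, 1))\n",
        "except RuntimeError:\n",
        "  pass\n",
        "\n",
        "# The restarted run repeats the interrupted epoch from its first step.\n",
        "toy_resumed = toy_fit(toy_backup_dir, epochs=3, steps_per_epoch=2)\n",
        "print(toy_resumed)"
      ]
    },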
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ega2hdOQEmy_"
      },
      "source": [
        "## See also\n",
        "1. The [Distributed Training in TensorFlow](https://www.tensorflow.org/guide/distributed_training) guide provides an overview of the available distribution strategies.\n",
        "2. [Official models](https://github.com/tensorflow/models/tree/master/official), many of which can be configured to run multiple distribution strategies.\n",
        "3. The [Performance section](../../guide/function.ipynb) in the guide provides information about other strategies and [tools](../../guide/profiler.md) you can use to optimize the performance of your TensorFlow models.\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "multi_worker_with_keras.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
